﻿using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Microsoft.Xna.Framework;
using Microsoft.Xna.Framework.Graphics;

namespace Utils
{
    public static class PickingHelper
    {
        /// <summary>
        /// Checks whether a ray intersects a model. This method needs to access
        /// the model vertex data, so the model must have been built using the
        /// custom TrianglePickingProcessor provided as part of this sample.
        /// </summary>
        /// <param name="ray">The picking ray, expressed in world space.</param>
        /// <param name="model">
        /// The model to test. Its Tag must be a dictionary containing a
        /// "BoundingSphere" entry and a "Vertices" entry, where the vertices are
        /// read as a flat list of triangles (three consecutive entries each).
        /// </param>
        /// <param name="modelTransform">The object-to-world transform of the model.</param>
        /// <param name="insideBoundingSphere">
        /// Set to true if the ray passed the coarse bounding sphere test (and the
        /// per-triangle test was therefore run), false otherwise.
        /// </param>
        /// <param name="vertex1">First world-space vertex of the closest triangle hit, or Vector3.Zero if nothing was hit.</param>
        /// <param name="vertex2">Second world-space vertex of the closest triangle hit, or Vector3.Zero if nothing was hit.</param>
        /// <param name="vertex3">Third world-space vertex of the closest triangle hit, or Vector3.Zero if nothing was hit.</param>
        /// <returns>
        /// The distance along the ray to the closest point of intersection, measured
        /// in multiples of the length of the ray direction (world units when the
        /// direction is a unit vector), or null if there is no intersection.
        /// </returns>
        /// <exception cref="ArgumentNullException">The model is null.</exception>
        /// <exception cref="InvalidOperationException">
        /// Model.Tag is null or is not the expected dictionary.
        /// </exception>
        public static float? RayIntersectsModel(Ray ray, Model model, Matrix modelTransform,
                                         out bool insideBoundingSphere,
                                         out Vector3 vertex1, out Vector3 vertex2,
                                         out Vector3 vertex3)
        {
            vertex1 = vertex2 = vertex3 = Vector3.Zero;

            if (model == null)
            {
                throw new ArgumentNullException("model");
            }

            // Look up our custom collision data from the Tag property of the model.
            // Using 'as' means a Tag of the wrong type is reported with the same
            // descriptive error as a missing Tag, rather than an InvalidCastException.
            var tagData = model.Tag as Dictionary<string, object>;

            if (tagData == null)
            {
                throw new InvalidOperationException(
                    "Model.Tag is not set correctly. Make sure your model " +
                    "was built using the custom TrianglePickingProcessor.");
            }

            // The input ray is in world space, but our model data is stored in object
            // space. We would normally have to transform all the model data by the
            // modelTransform matrix, moving it into world space before we test it
            // against the ray. That transform can be slow if there are a lot of
            // triangles in the model, however, so instead we do the opposite.
            // Transforming our ray by the inverse modelTransform moves it into object
            // space, where we can test it directly against our model data. Since there
            // is only one ray but typically many triangles, doing things this way
            // around can be much faster.
            Matrix inverseTransform = Matrix.Invert(modelTransform);

            ray.Position = Vector3.Transform(ray.Position, inverseTransform);

            // The direction is deliberately NOT renormalized here: keeping the
            // transformed length means a ray parameter t found in object space
            // refers to the same point as parameter t on the original world ray,
            // so the distance we return is valid for the caller's ray.
            ray.Direction = Vector3.TransformNormal(ray.Direction, inverseTransform);

            // Start off with a fast bounding sphere test.
            var boundingSphere = (BoundingSphere)tagData["BoundingSphere"];

            // The sphere test expects a unit-length direction. If modelTransform
            // contains any scale, the object-space direction is no longer unit
            // length, which could make the sphere test reject rays that really do
            // hit the model. Use a normalized copy for this test only.
            var sphereTestRay = new Ray(ray.Position, Vector3.Normalize(ray.Direction));

            if (boundingSphere.Intersects(sphereTestRay) == null)
            {
                // If the ray does not intersect the bounding sphere, we cannot
                // possibly have picked this model, so there is no need to even
                // bother looking at the individual triangle data.
                insideBoundingSphere = false;

                return null;
            }

            // The bounding sphere test passed, so we need to do a full
            // triangle picking test.
            insideBoundingSphere = true;

            // Keep track of the closest triangle we found so far,
            // so we can always return the closest one.
            float? closestIntersection = null;

            // Loop over the vertex data, 3 at a time (3 vertices = 1 triangle).
            // The bound is on i + 2 so that a trailing partial triangle (vertex
            // count not a multiple of three) is ignored instead of indexing past
            // the end of the array.
            var vertices = (Vector3[])tagData["Vertices"];

            for (var i = 0; i + 2 < vertices.Length; i += 3)
            {
                // Perform a ray to triangle intersection test.
                float? intersection = RayIntersectsTriangle(ref ray,
                                      ref vertices[i],
                                      ref vertices[i + 1],
                                      ref vertices[i + 2]);

                // Does the ray intersect this triangle?
                if (intersection == null)
                    continue;

                // If so, is it closer than any other previous triangle?
                if ((closestIntersection != null) && (intersection >= closestIntersection))
                    continue;

                // Store the distance to this triangle.
                closestIntersection = intersection;

                // Transform the three vertex positions into world space,
                // and store them into the output vertex parameters.
                Vector3.Transform(ref vertices[i],
                                  ref modelTransform, out vertex1);

                Vector3.Transform(ref vertices[i + 1],
                                  ref modelTransform, out vertex2);

                Vector3.Transform(ref vertices[i + 2],
                                  ref modelTransform, out vertex3);
            }

            return closestIntersection;
        }


        /// <summary>
        /// Checks whether a ray intersects a triangle. This uses the algorithm
        /// developed by Tomas Moller and Ben Trumbore, which was published in the
        /// Journal of Graphics Tools, volume 2, "Fast, Minimum Storage Ray-Triangle
        /// Intersection".
        /// 
        /// This method is implemented using the pass-by-reference versions of the
        /// XNA math functions. Using these overloads is generally not recommended,
        /// because they make the code less readable than the normal pass-by-value
        /// versions. This method can be called very frequently in a tight inner loop,
        /// however, so in this particular case the performance benefits from passing
        /// everything by reference outweigh the loss of readability.
        /// </summary>
        /// <param name="ray">The ray to test, in the same space as the vertices. It is not modified.</param>
        /// <param name="vertex1">First vertex of the triangle. It is not modified.</param>
        /// <param name="vertex2">Second vertex of the triangle. It is not modified.</param>
        /// <param name="vertex3">Third vertex of the triangle. It is not modified.</param>
        /// <returns>
        /// The distance along the ray to the intersection point, in multiples of the
        /// length of the ray direction, or null if the ray misses the triangle, is
        /// parallel to it, or the triangle lies behind the ray origin.
        /// </returns>
        public static float? RayIntersectsTriangle(ref Ray ray, ref Vector3 vertex1, ref Vector3 vertex2, ref Vector3 vertex3)
        {
            // Compute vectors along two edges of the triangle.
            Vector3 edge1, edge2;

            Vector3.Subtract(ref vertex2, ref vertex1, out edge1);
            Vector3.Subtract(ref vertex3, ref vertex1, out edge2);

            // Compute the determinant.
            Vector3 directionCrossEdge2;
            Vector3.Cross(ref ray.Direction, ref edge2, out directionCrossEdge2);

            float determinant;
            Vector3.Dot(ref edge1, ref directionCrossEdge2, out determinant);

            // If the ray is parallel to the triangle plane, there is no collision.
            // Note that float.Epsilon is the smallest positive float, so this test
            // effectively only catches a determinant of exactly zero.
            if (determinant > -float.Epsilon && determinant < float.Epsilon)
            {
                return null;
            }

            float inverseDeterminant = 1.0f / determinant;

            // A determinant that is non-zero but extremely small (or NaN) yields a
            // reciprocal that is not a finite number. Every value derived from it
            // would be Infinity or NaN, so treat the case as "no collision" rather
            // than letting a bogus distance reach the caller.
            if (float.IsNaN(inverseDeterminant) || float.IsInfinity(inverseDeterminant))
            {
                return null;
            }

            // Calculate the U parameter of the intersection point.
            Vector3 distanceVector;
            Vector3.Subtract(ref ray.Position, ref vertex1, out distanceVector);

            float triangleU;
            Vector3.Dot(ref distanceVector, ref directionCrossEdge2, out triangleU);
            triangleU *= inverseDeterminant;

            // Make sure it is inside the triangle. The test is written as a negated
            // "in range" check so that a NaN value is rejected as well.
            if (!(triangleU >= 0 && triangleU <= 1))
            {
                return null;
            }

            // Calculate the V parameter of the intersection point.
            Vector3 distanceCrossEdge1;
            Vector3.Cross(ref distanceVector, ref edge1, out distanceCrossEdge1);

            float triangleV;
            Vector3.Dot(ref ray.Direction, ref distanceCrossEdge1, out triangleV);
            triangleV *= inverseDeterminant;

            // Make sure it is inside the triangle (again rejecting NaN).
            if (!(triangleV >= 0 && triangleU + triangleV <= 1))
            {
                return null;
            }

            // Compute the distance along the ray to the triangle.
            float rayDistance;
            Vector3.Dot(ref edge2, ref distanceCrossEdge1, out rayDistance);
            rayDistance *= inverseDeterminant;

            // Is the triangle behind the ray origin? A NaN distance is rejected
            // here too, since it fails the ">= 0" comparison.
            if (!(rayDistance >= 0))
            {
                return null;
            }

            return rayDistance;
        }

        ///// <summary>
        ///// This helper function takes a BoundingSphere and a transform matrix, and
        ///// returns a transformed version of that BoundingSphere.
        ///// </summary>
        ///// <param name="sphere">the BoundingSphere to transform</param>
        ///// <param name="transform">how to transform the BoundingSphere.</param>
        ///// <returns>the transformed BoundingSphere.</returns>
        //private static BoundingSphere TransformBoundingSphere(BoundingSphere sphere, Matrix transform)
        //{
        //    BoundingSphere transformedSphere;

        //    // the transform can contain different scales on the x, y, and z components.
        //    // this has the effect of stretching and squishing our bounding sphere along
        //    // different axes. Obviously, this is no good: a bounding sphere has to be a
        //    // SPHERE. so, the transformed sphere's radius must be the maximum of the 
        //    // scaled x, y, and z radii.

        //    // to calculate how the transform matrix will affect the x, y, and z
        //    // components of the sphere, we'll create a vector3 with x y and z equal
        //    // to the sphere's radius...
        //    Vector3 scale3 = new Vector3(sphere.Radius, sphere.Radius, sphere.Radius);

        //    // then transform that vector using the transform matrix. we use
        //    // TransformNormal because we don't want to take translation into account.
        //    scale3 = Vector3.TransformNormal(scale3, transform);

        //    // scale3 contains the x, y, and z radii of a squished and stretched sphere.
        //    // we'll set the finished sphere's radius to the maximum of the x y and z
        //    // radii, creating a sphere that is large enough to contain the original 
        //    // squished sphere.
        //    transformedSphere.Radius = System.Math.Max(scale3.X, System.Math.Max(scale3.Y, scale3.Z));

        //    // transforming the center of the sphere is much easier. we can just use 
        //    // Vector3.Transform to transform the center vector. notice that we're using
        //    // Transform instead of TransformNormal because in this case we DO want to 
        //    // take translation into account.
        //    transformedSphere.Center = Vector3.Transform(sphere.Center, transform);


        //    return transformedSphere;
        //}
    }
}
